This document assesses the ability of several models to classify StraightenedWorm data. Manual annotations are used to train, validate, and select the best model and hyperparameters.
First load the data:
library(tidymodels)
library(tidyverse)
library(here)
# Locate every manual-annotation CSV anywhere under the project root.
# The pattern escapes the dot and anchors the end of the name so only files
# actually ending in "_manual.csv" match (the original '.*_manual.csv'
# treated '.' as a wildcard and was unanchored).
files <- tibble(path = list.files(path = here(),
                                  pattern = '_manual\\.csv$',
                                  recursive = TRUE))

# Read one annotation file. Called row-wise via pmap_dfr() below, so `...`
# receives the columns of a single row of `files` (here just `path`).
# Return the parsed tibble directly instead of relying on the value of an
# assignment as the function's implicit return.
get_data <- function(...) {
  df <- tibble(...)
  read_csv(here(df$path))
}
# Combine every annotation file, keep only annotated rows, and recode the
# single-letter worm codes into descriptive factor levels. (The trailing
# comma after the last case_when() clause was removed: it creates an empty
# argument that only newer rlang versions silently tolerate.)
annotations <- files %>%
  pmap_dfr(get_data) %>%
  janitor::clean_names() %>%
  filter(!is.na(worm)) %>%
  select(worm, contains('area')) %>%
  mutate(worm = case_when(
    worm == 'Y' ~ 'Single worm',
    worm == 'N' ~ 'Debris',
    worm == 'P' ~ 'Partial worm',
    worm == 'M' ~ 'Multiple worms'
  )) %>%
  mutate(worm = as.factor(worm))
glimpse(annotations)
## Rows: 4,821
## Columns: 25
## $ worm <fct> Debris, Debris, Single worm, Single …
## $ area_shape_area <dbl> 5205, 2031, 2469, 2030, 2497, 2152, …
## $ area_shape_bounding_box_area <dbl> 6069, 2394, 2919, 2394, 2919, 2541, …
## $ area_shape_bounding_box_maximum_x <dbl> 21, 42, 63, 84, 105, 126, 147, 168, …
## $ area_shape_bounding_box_maximum_y <dbl> 296, 121, 146, 121, 146, 128, 244, 1…
## $ area_shape_bounding_box_minimum_x <dbl> 0, 21, 42, 63, 84, 105, 126, 147, 16…
## $ area_shape_bounding_box_minimum_y <dbl> 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, …
## $ area_shape_center_x <dbl> 10.13698, 31.13639, 51.49494, 72.997…
## $ area_shape_center_y <dbl> 150.06724, 63.26686, 75.94208, 63.09…
## $ area_shape_compactness <dbl> 6.075054, 2.603662, 3.150583, 2.5034…
## $ area_shape_eccentricity <dbl> 0.9973402, 0.9832310, 0.9884558, 0.9…
## $ area_shape_equivalent_diameter <dbl> 81.40769, 50.85223, 56.06807, 50.839…
## $ area_shape_euler_number <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, …
## $ area_shape_extent <dbl> 0.8576372, 0.8483709, 0.8458376, 0.8…
## $ area_shape_form_factor <dbl> 0.1646076, 0.3840745, 0.3174016, 0.3…
## $ area_shape_major_axis_length <dbl> 302.82179, 119.52235, 144.60349, 117…
## $ area_shape_max_feret_diameter <dbl> 288.00174, 113.03982, 138.00362, 113…
## $ area_shape_maximum_radius <dbl> 11.00000, 11.00000, 10.19804, 11.000…
## $ area_shape_mean_radius <dbl> 5.031499, 4.889537, 4.976384, 5.0222…
## $ area_shape_median_radius <dbl> 5.000000, 5.000000, 5.000000, 5.0000…
## $ area_shape_min_feret_diameter <dbl> 20, 20, 20, 20, 20, 20, 20, 20, 20, …
## $ area_shape_minor_axis_length <dbl> 22.07184, 21.79666, 21.90884, 22.154…
## $ area_shape_orientation <dbl> -0.066591089, -0.161518210, 0.040796…
## $ area_shape_perimeter <dbl> 630.3625, 257.7817, 312.6518, 252.71…
## $ area_shape_solidity <dbl> 0.9553965, 0.9562147, 0.9614486, 0.9…
The Worm Toolbox in Cell Profiler can export a variety of features, some of which may be useful in classification.
library(ggbeeswarm)

# Bee-swarm view of every exported feature by annotated class, one facet
# per measurement, each on its own y scale.
annotations_long <- annotations %>%
  pivot_longer(-worm, names_to = 'measurement', values_to = 'value')
ggplot(annotations_long) +
  geom_quasirandom(aes(x = worm, y = value, color = worm)) +
  facet_wrap(facets = vars(measurement), scales = 'free_y') +
  theme_minimal() +
  NULL
First create training (with cross-fold validation) and test data sets
# `worm` is already a factor (set when `annotations` was built above), so
# the original re-conversion was a no-op; keep the `model_data` name for
# the modelling steps below.
model_data <- annotations

set.seed(123)
# data_boot <- bootstraps(model_data, times = 2) # only 2 bootstraps for testing
# Stratified initial split (default 3/4 training) preserving class balance.
data_split <- initial_split(model_data,
                            strata = worm)
train_data <- training(data_split)
test_data <- testing(data_split)

set.seed(234)
# 10-fold cross-validation on the training set, again stratified by class.
folds <- vfold_cv(train_data,
                  v = 10,
                  strata = worm)
Specify the models (multinomial regression, decision tree, and random forest):
# Candidate model specifications. Every hyperparameter marked tune() is
# optimized over the cross-validation folds below.

# CART decision tree.
decision_tree_rpart_spec <-
  decision_tree(tree_depth = tune(), min_n = tune(), cost_complexity = tune()) %>%
  set_engine('rpart') %>%
  set_mode('classification')

# Penalized multinomial regression. Classification is multinom_reg()'s only
# mode, but set it explicitly for consistency with the other specs.
multinom_reg_glmnet_spec <-
  multinom_reg(penalty = tune(), mixture = tune()) %>%
  set_engine('glmnet') %>%
  set_mode('classification')

cores <- parallel::detectCores()

# Random forest, parallelized across all available cores.
rand_forest_ranger_spec <-
  rand_forest(mtry = tune(), min_n = tune()) %>%
  set_engine('ranger', num.threads = cores) %>%
  set_mode('classification')

# Polynomial-kernel support vector machine.
svm_poly_kernlab_spec <-
  svm_poly(cost = tune(), degree = tune(), scale_factor = tune(), margin = tune()) %>%
  set_engine('kernlab') %>%
  set_mode('classification')

# Gradient-boosted trees.
boost_tree_xgboost_spec <-
  boost_tree(tree_depth = tune(), learn_rate = tune(),
             min_n = tune(), loss_reduction = tune(), mtry = tune(),
             sample_size = tune(), stop_iter = tune()) %>%
  set_engine('xgboost') %>%
  set_mode('classification')
Build the recipe and workflow:
library(themis)
# Preprocessing: drop near-zero-variance predictors, center/scale, drop one
# of each pair of highly correlated predictors (|r| > .5), then
# SMOTE-oversample the minority classes to balance `worm`.
# NOTE(review): the recipe template is model_data (all rows); tidymodels
# convention is to template on train_data. tune_grid() re-preps on each
# fold's analysis set, so tuning is unaffected, but the prep()/juice()
# below operate on the full 4,821-row data set.
# NOTE(review): the objects `recipe` and `prep` shadow the recipes
# generics of the same names; later chunks reference these objects, so the
# names are kept as-is.
recipe <-
recipe(worm ~ ., data = model_data) %>%
step_nzv(all_predictors()) %>%
step_normalize(all_predictors()) %>%
step_corr(all_numeric_predictors(), threshold = .5) %>%
step_smote(worm)
prep <- prep(recipe)
juice <- juice(prep)
prep
## Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 24
##
## Training data contained 4821 data points and no missing data.
##
## Operations:
##
## Sparse, unbalanced variable filter removed area_shape_euler_number, area_shape_max... [trained]
## Centering and scaling for area_shape_area, area_shape_bounding_box_area, ... [trained]
## Correlation filter removed area_shape_bounding_box_maximum_y, area_s... [trained]
## SMOTE based on worm [trained]
# Inspect the preprocessed (juiced) data: SMOTE has balanced the classes
# (13,132 rows) and the filters have reduced the predictors to five.
glimpse(juice)
## Rows: 13,132
## Columns: 6
## $ area_shape_bounding_box_minimum_x <dbl> -1.2442535, -1.1710140, -1.0977744, …
## $ area_shape_bounding_box_minimum_y <dbl> 0.5178302, 0.5178302, 0.5178302, 0.5…
## $ area_shape_eccentricity <dbl> 0.30806485, 0.13513151, 0.19917063, …
## $ area_shape_minor_axis_length <dbl> -0.7834696, -1.1095671, -0.9766285, …
## $ area_shape_orientation <dbl> -0.076386694, -0.113077317, -0.03487…
## $ worm <fct> Debris, Debris, Single worm, Single …
# recipe2 marks the correlation filter (step 3) with skip = TRUE so the
# tree-based workflows are not limited by it.
# NOTE(review): skip = TRUE means the step is still trained and applied to
# the analysis data during prep(), but skipped when bake() runs on new data
# (assessment folds / test set) -- confirm this train/predict asymmetry is
# intended; the alternative is a second recipe without step_corr() at all.
recipe2 <- recipe
recipe2$steps[[3]] <- update(recipe2$steps[[3]], skip = TRUE)
# Bundle each model spec with its preprocessing. The tree-based models
# (decision tree, random forest) use recipe2; the rest use the full recipe.
dt_workflow <- workflow() %>%
  add_recipe(recipe2) %>%
  add_model(decision_tree_rpart_spec)

mn_workflow <- workflow() %>%
  add_recipe(recipe) %>%
  add_model(multinom_reg_glmnet_spec)

rf_workflow <- workflow() %>%
  add_recipe(recipe2) %>%
  add_model(rand_forest_ranger_spec)

svm_poly_workflow <- workflow() %>%
  add_recipe(recipe) %>%
  add_model(svm_poly_kernlab_spec)

xg_workflow <- workflow() %>%
  add_recipe(recipe) %>%
  add_model(boost_tree_xgboost_spec)
# Regular 5-level grid over the three decision-tree hyperparameters
# (5^3 = 125 candidates).
dt_grid <- grid_regular(cost_complexity(),
                        tree_depth(),
                        min_n(),
                        levels = 5)

# tune on the train data
dt_tune <-
  dt_workflow %>%
  tune_grid(
    resamples = folds,
    grid = dt_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens)
  )
# Cache the (slow) tuning result so the chunk can be re-run from disk.
write_rds(dt_tune, here('code', 'rds', 'dt_tune.rds'))
dt_tune <- read_rds(here('code', 'rds', 'dt_tune.rds'))

# extract the best decision tree; pass the metric by name (positional use
# of select_best() is deprecated in current tune)
best_tree <- dt_tune %>%
  select_best(metric = "roc_auc")

# print metrics for the selected configuration
(dt_metrics <- dt_tune %>%
  collect_metrics() %>%
  semi_join(best_tree) %>%
  select(.metric:.config) %>%
  mutate(model = 'Decision tree'))
## # A tibble: 2 × 7
## .metric .estimator mean n std_err .config model
## <chr> <chr> <dbl> <int> <dbl> <chr> <chr>
## 1 roc_auc hand_till 0.813 10 0.0113 Preprocessor1_Model114 Decision tree
## 2 sens macro 0.586 10 0.0168 Preprocessor1_Model114 Decision tree
# Lock the winning hyperparameters into the decision-tree workflow.
dt_workflow <- finalize_workflow(dt_workflow, best_tree)

# ROC curves for the selected tree, built from the saved cross-validation
# predictions (not from the held-out test set).
dt_auc <- dt_tune %>%
  collect_predictions(parameters = best_tree) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "Decision tree")
autoplot(dt_auc)
# Regular grid (default 3 levels per parameter) for the glmnet model.
mn_grid <- grid_regular(mixture(),
                        penalty())

mn_tune <-
  mn_workflow %>%
  tune_grid(
    resamples = folds,
    grid = mn_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens))
# Cache the tuning result to disk.
write_rds(mn_tune, here('code', 'rds', 'mn_tune.rds'))
mn_tune <- read_rds(here('code', 'rds', 'mn_tune.rds'))

# extract the best model; pass the metric by name (positional use of
# select_best() is deprecated in current tune)
best_mn <- mn_tune %>%
  select_best(metric = "roc_auc")

# print metrics
(mn_metrics <- mn_tune %>%
  collect_metrics() %>%
  semi_join(best_mn) %>%
  select(.metric:.config) %>%
  mutate(model = 'Multinomial regression'))
## # A tibble: 2 × 7
## .metric .estimator mean n std_err .config model
## <chr> <chr> <dbl> <int> <dbl> <chr> <chr>
## 1 roc_auc hand_till 0.790 10 0.00847 Preprocessor1_Model7 Multinomial regre…
## 2 sens macro 0.537 10 0.0174 Preprocessor1_Model7 Multinomial regre…
# Lock the winning hyperparameters into the multinomial workflow.
mn_workflow <- finalize_workflow(mn_workflow, best_mn)

# ROC curves from the saved cross-validation predictions (not the test set).
mn_auc <- mn_tune %>%
  collect_predictions(parameters = best_mn) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "Multinomial regression")
autoplot(mn_auc)
# mtry's upper bound is data-dependent and must be finalized. Finalize on
# the training predictors only: the original used model_data, which is the
# full data set and also counts the outcome column `worm` as a predictor.
# NOTE(review): the recipe's nzv/corr filters shrink the predictor set
# further, so this bound remains an over-estimate.
rf_grid <- grid_regular(finalize(mtry(), select(train_data, -worm)),
                        min_n())

rf_tune <-
  rf_workflow %>%
  tune_grid(
    resamples = folds,
    grid = rf_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens))
# Cache the tuning result to disk.
write_rds(rf_tune, here('code', 'rds', 'rf_tune.rds'))
rf_tune <- read_rds(here('code', 'rds', 'rf_tune.rds'))

# extract the best model; pass the metric by name (positional use of
# select_best() is deprecated in current tune)
best_rf <- rf_tune %>%
  select_best(metric = "roc_auc")

# print metrics
(rf_metrics <- rf_tune %>%
  collect_metrics() %>%
  semi_join(best_rf) %>%
  select(.metric:.config) %>%
  mutate(model = 'Random forest'))
## # A tibble: 2 × 7
## .metric .estimator mean n std_err .config model
## <chr> <chr> <dbl> <int> <dbl> <chr> <chr>
## 1 roc_auc hand_till 0.845 10 0.0110 Preprocessor1_Model7 Random forest
## 2 sens macro 0.605 10 0.0170 Preprocessor1_Model7 Random forest
# Lock the winning hyperparameters into the random-forest workflow.
rf_workflow <- finalize_workflow(rf_workflow, best_rf)

# ROC curves from the saved cross-validation predictions (not the test set).
rf_auc <- rf_tune %>%
  collect_predictions(parameters = best_rf) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "Random forest")
autoplot(rf_auc)
# Regular grid over the four polynomial-SVM hyperparameters.
svm_grid <- grid_regular(cost(),
                         degree(),
                         scale_factor(),
                         svm_margin())

svm_tune <-
  svm_poly_workflow %>%
  tune_grid(
    resamples = folds,
    grid = svm_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens))
# Cache the tuning result to disk.
write_rds(svm_tune, here('code', 'rds', 'svm_tune.rds'))
svm_tune <- read_rds(here('code', 'rds', 'svm_tune.rds'))

# extract the best svm; pass the metric by name (positional use of
# select_best() is deprecated in current tune)
best_svm <- svm_tune %>%
  select_best(metric = "roc_auc")

# print metrics. NOTE(review): the printed n = 4 (vs 10 folds) suggests
# this configuration failed on some resamples -- worth inspecting with
# collect_notes(svm_tune).
(svm_metrics <- svm_tune %>%
  collect_metrics() %>%
  semi_join(best_svm) %>%
  select(.metric:.config) %>%
  mutate(model = 'SVM'))
## # A tibble: 2 × 7
## .metric .estimator mean n std_err .config model
## <chr> <chr> <dbl> <int> <dbl> <chr> <chr>
## 1 roc_auc hand_till 0.805 4 0.00611 Preprocessor1_Model27 SVM
## 2 sens macro 0.615 4 0.0166 Preprocessor1_Model27 SVM
# Lock the winning hyperparameters into a finalized SVM workflow.
svm_workflow <- finalize_workflow(svm_poly_workflow, best_svm)

# ROC curves from the saved cross-validation predictions (not the test set).
svm_auc <- svm_tune %>%
  collect_predictions(parameters = best_svm) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "SVM")
autoplot(svm_auc)
# Latin-hypercube sample of 30 candidate xgboost configurations. mtry is
# finalized on the training predictors only: the original passed train_data
# whole, which counts the outcome column `worm` and inflates the upper
# bound by one.
xg_grid <- grid_latin_hypercube(
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), select(train_data, -worm)),
  learn_rate(),
  stop_iter(),
  size = 30
)

# tune on the train data
xg_tune <-
  xg_workflow %>%
  tune_grid(
    resamples = folds,
    grid = xg_grid,
    control = control_grid(save_pred = TRUE,
                           verbose = TRUE),
    metrics = metric_set(roc_auc, sens)
  )
# Cache the (slow) tuning result to disk.
write_rds(xg_tune, here('code', 'rds', 'xg_tune.rds'))
xg_tune <- read_rds(here('code', 'rds', 'xg_tune.rds'))

# extract the best configuration; pass the metric by name (positional use
# of select_best() is deprecated in current tune)
best_xg <- xg_tune %>%
  select_best(metric = "roc_auc")

# print metrics
(xg_metrics <- xg_tune %>%
  collect_metrics() %>%
  semi_join(best_xg) %>%
  select(.metric:.config) %>%
  mutate(model = 'XGBoost'))
## # A tibble: 2 × 7
## .metric .estimator mean n std_err .config model
## <chr> <chr> <dbl> <int> <dbl> <chr> <chr>
## 1 roc_auc hand_till 0.841 10 0.00973 Preprocessor1_Model10 XGBoost
## 2 sens macro 0.601 10 0.0175 Preprocessor1_Model10 XGBoost
# Lock the winning hyperparameters into the xgboost workflow.
xg_workflow <- finalize_workflow(xg_workflow, best_xg)

# ROC curves from the saved cross-validation predictions (not the test set).
xg_auc <- xg_tune %>%
  collect_predictions(parameters = best_xg) %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  mutate(model = "XGBoost")
autoplot(xg_auc)
Evaluate using ROC AUC.
# Rank every model within each metric. arrange() ignores group_by() unless
# .by_group = TRUE, so the original grouping had no effect on the sort;
# request the grouped ordering explicitly (groups first, then descending
# mean within each group).
(all_metrics <- bind_rows(dt_metrics, mn_metrics, rf_metrics, svm_metrics, xg_metrics) %>%
  group_by(.metric) %>%
  arrange(desc(mean), .by_group = TRUE))
## # A tibble: 10 × 7
## # Groups: .metric [2]
## .metric .estimator mean n std_err .config model
## <chr> <chr> <dbl> <int> <dbl> <chr> <chr>
## 1 roc_auc hand_till 0.845 10 0.0110 Preprocessor1_Model7 Random forest
## 2 roc_auc hand_till 0.841 10 0.00973 Preprocessor1_Model10 XGBoost
## 3 roc_auc hand_till 0.813 10 0.0113 Preprocessor1_Model114 Decision tree
## 4 roc_auc hand_till 0.805 4 0.00611 Preprocessor1_Model27 SVM
## 5 roc_auc hand_till 0.790 10 0.00847 Preprocessor1_Model7 Multinomial re…
## 6 sens macro 0.615 4 0.0166 Preprocessor1_Model27 SVM
## 7 sens macro 0.605 10 0.0170 Preprocessor1_Model7 Random forest
## 8 sens macro 0.601 10 0.0175 Preprocessor1_Model10 XGBoost
## 9 sens macro 0.586 10 0.0168 Preprocessor1_Model114 Decision tree
## 10 sens macro 0.537 10 0.0174 Preprocessor1_Model7 Multinomial re…
# Overlay the cross-validated ROC curves of all five models, one facet per
# class, with the diagonal as the chance line.
all_auc <- bind_rows(dt_auc, mn_auc, rf_auc, svm_auc, xg_auc)
(all_models <- ggplot(all_auc,
                      aes(x = 1 - specificity, y = sensitivity, col = model)) +
  geom_path(lwd = 1.5, alpha = 0.8) +
  geom_abline(lty = 3) +
  coord_equal() +
  scale_color_viridis_d(option = "plasma") +
  facet_wrap(facets = vars(.level)) +
  theme_minimal() +
  NULL)
Random forest and XGBoost consistently perform the best across all 4 classes. Now I fit to the test data using the best parameters and evaluate the model’s performance.
# Final model: XGBoost with the tuned hyperparameters and a fixed, generous
# number of trees. All tuned parameters are carried over -- the original
# dropped sample_size and stop_iter -- and the best_xg values are used
# directly instead of top-level locals that shadowed the dials functions
# mtry()/min_n()/tree_depth()/etc.
# NOTE(review): importance = "impurity" is a ranger engine argument, not an
# xgboost one, so it was removed; vip() reads importance from the fitted
# booster directly.
last_mod <-
  boost_tree(mtry = best_xg$mtry,
             trees = 1000,
             min_n = best_xg$min_n,
             tree_depth = best_xg$tree_depth,
             learn_rate = best_xg$learn_rate,
             loss_reduction = best_xg$loss_reduction,
             sample_size = best_xg$sample_size,
             stop_iter = best_xg$stop_iter) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

last_workflow <-
  xg_workflow %>%
  update_model(last_mod)

set.seed(345)
# Fit on the full training set and evaluate once on the held-out test set.
# NOTE(review): this object shadows tune::last_fit(); the name is kept
# because later chunks reference `last_fit`.
last_fit <-
  last_workflow %>%
  last_fit(data_split,
           metrics = metric_set(roc_auc, sens))
# Metrics on the held-out test set (computed once by last_fit()).
collect_metrics(last_fit)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 sens macro 0.590 Preprocessor1_Model1
## 2 roc_auc hand_till 0.832 Preprocessor1_Model1
# Variable importance from the fitted xgboost booster.
last_fit %>%
  extract_fit_engine() %>%
  vip::vip() +
  theme_minimal()

# Test-set predictions, reused for the ROC curves and confusion matrix.
test_predictions <- collect_predictions(last_fit)

# Per-class ROC curves on the test set.
(final_auc <- test_predictions %>%
  roc_curve(.pred_Debris:`.pred_Single worm`, truth = worm) %>%
  autoplot())

# Confusion matrix of predicted vs true class on the test set.
test_predictions %>%
  conf_mat(truth = worm, estimate = .pred_class) %>%
  autoplot()
The model performs about as well on the held-out test data (ROC AUC 0.832, macro sensitivity 0.590) as in cross-validation (0.841 and 0.601), indicating that we aren’t badly overfitting.
In a situation where we probably have more data points than are truly necessary to be able to draw defensible inferences, we are most concerned with accurate identification of a Single Worm. By that I mean that we are ok if the false negative rate is high (i.e., a Single Worm is identified as either Debris, Partial, or Multiple). Thus, we want few false positives among the objects predicted to be Single Worms — that is, high positive predictive value (PPV) — even at the cost of sensitivity. Using the selected model and the test set, here’s what would happen if we only kept the StraightenedWorms predicted to be a Single Worm:
# Keep only the objects predicted to be a single worm and tabulate them
# against the manual annotation: the off-diagonal entries of the
# 'Single worm' row are contaminants that would slip through the filter.
kept_predictions <- last_fit %>%
  collect_predictions() %>%
  filter(.pred_class == 'Single worm')
conf_mat(kept_predictions, truth = worm, estimate = .pred_class)
## Truth
## Prediction Debris Multiple worms Partial worm Single worm
## Debris 0 0 0 0
## Multiple worms 0 0 0 0
## Partial worm 0 0 0 0
## Single worm 31 39 38 598
# Class sizes in the full test set, for comparison with the filtered table.
collect_predictions(last_fit) %>%
  group_by(worm) %>%
  summarise(n())
## # A tibble: 4 × 2
## worm `n()`
## <fct> <int>
## 1 Debris 155
## 2 Multiple worms 98
## 3 Partial worm 137
## 4 Single worm 817
# Persist the final fitted workflow so it can be reused for prediction.
final_wf <- extract_workflow(last_fit)
write_rds(final_wf, here('code', 'rds', 'final_workflow.rds'))
# Pre-filter view: all annotated objects with per-class counts.
pre_filter <- annotations %>%
  select(worm, area_shape_major_axis_length) %>%
  ggplot(aes(x = worm, y = area_shape_major_axis_length)) +
  geom_quasirandom(aes(color = worm)) +
  geom_text(data = . %>% group_by(worm) %>% summarise(n = n()),
            aes(label = n), y = 550) +
  theme_minimal() +
  labs(title = 'Pre-filter') +
  lims(y = c(0, 600)) +
  # 'none' is the documented way to suppress the legend; 'empty' is not a
  # valid legend.position value.
  theme(legend.position = 'none')

# Post-filter view: only objects the model predicts as Single worm.
post_filter <- augment(final_wf, annotations) %>%
  filter(.pred_class == 'Single worm') %>%
  select(worm, area_shape_major_axis_length) %>%
  ggplot(aes(x = worm, y = area_shape_major_axis_length)) +
  geom_quasirandom(aes(color = worm)) +
  geom_text(data = . %>% group_by(worm) %>% summarise(n = n()),
            aes(label = n), y = 550) +
  theme_minimal() +
  labs(title = 'Post-filter') +
  lims(y = c(0, 600)) +
  theme(legend.position = 'none')

# Side-by-side comparison on matched axes.
cowplot::plot_grid(pre_filter, post_filter, nrow = 1, align = 'h', axis = 'tb')
# Fraction of each class removed by keeping only predicted Single worms.
# The join key is named explicitly (silences the implicit-key message), and
# the unused selection of area_shape_major_axis_length was dropped: only
# the grouping column matters for the counts.
(percent_loss <- annotations %>%
  group_by(worm) %>%
  summarise(pre_filter = n()) %>%
  left_join(
    augment(final_wf, annotations) %>%
      filter(.pred_class == 'Single worm') %>%
      group_by(worm) %>%
      summarise(post_filter = n()),
    by = 'worm'
  ) %>%
  mutate(percent_loss = 1 - post_filter / pre_filter))
## # A tibble: 4 × 4
## worm pre_filter post_filter percent_loss
## <fct> <int> <int> <dbl>
## 1 Debris 589 63 0.893
## 2 Multiple worms 366 71 0.806
## 3 Partial worm 583 94 0.839
## 4 Single worm 3283 2696 0.179